%% implementation and simulation of Alg. 1
%
% Input: idxtrue        index of the candidate model that corresponds to the
%                       true dynamics
%        (Atot,Btot)    two cell arrays containing the (A,B) system matrices
%                       of all candidate models
%        Q,R            define the stage cost as x' Q x + u' R u
%        Ktot           cell array of feedback gains; Ktot{i} is the gain
%                       corresponding to the dynamics (Atot{i},Btot{i})
%        T              time horizon / simulation time
%
% Output: cost          stage cost over horizon (length T)
%         cost_opt      stage cost of policy Ktot{idxtrue} over horizon
%                       (length T) with same process noise realization as cost
%         trajectories  struct with state trajectory (xtraj), input
%                       trajectory (utraj), state trajectory corresponding
%                       to Ktot{idxtrue} (xtraj_opt), and the excitation
%                       magnitude applied at each step (sigma_uk)
%         estimates     struct with trajectory of probabilities over models
%                       (pk) and selected candidate models (ik)
%
%
function [cost,cost_opt,trajectories,estimates]=runExperimentExperts(idxtrue,Atot,Btot,Ktot,Q,R,T)
% Simulates two closed loops on the true dynamics (Atot{idxtrue},Btot{idxtrue}):
%   1. the learning controller, which applies the gain Ktot{ik(k)} of a
%      candidate model sampled from a multiplicative-weights distribution
%      over all candidate models, plus Gaussian excitation;
%   2. the benchmark controller, which always applies Ktot{idxtrue}.
%
% Both closed loops are driven by the SAME process noise sample at every
% step, so that cost and cost_opt are directly comparable on a common noise
% realization.
%
% Inputs:
%   idxtrue   index into Atot/Btot/Ktot of the model used to simulate
%   Atot,Btot cell arrays of candidate system matrices
%   Ktot      cell array of feedback gains, one per candidate model
%   Q,R       stage cost weights (stage cost is x'*Q*x + u'*R*u)
%   T         number of simulated time steps
%
% Outputs:
%   cost          1-by-T stage cost of the learning controller
%   cost_opt      1-by-T stage cost of the benchmark controller
%   trajectories  struct with fields xtraj, utraj, xtraj_opt, sigma_uk
%   estimates     struct with fields pk (model probabilities over time)
%                 and ik (selected candidate model over time)

% extract parameters
no=length(Atot);    % number of candidate models
n=size(Q,1);        % state dimension
nu=size(R,1);       % input dimension

x1=zeros(n,1);      % initial condition

% true dynamics and benchmark gain (loop invariant)
Atrue=Atot{idxtrue};
Btrue=Btot{idxtrue};
Kopt=Ktot{idxtrue};
Acl_opt=Atrue+Btrue*Kopt;   % benchmark closed-loop matrix

% initialize weights
wk=zeros(no,T);
wk(:,1)=ones(no,1);

% initialize losses: lk(:,k) is the squared one-step prediction error of
% each candidate model at step k; lsum is the running sum of lk over time
lk=zeros(no,T);
lsum=zeros(no,1);

% initialize distribution over models, initialized to uniform
pk=zeros(no,T);
pk(:,1)=ones(no,1)/no;

% initialize parameters
eta=10;         % learning rate
sigma=1;        % process noise
M=5;            % persistence of excitation parameter (resampling period)
sigma_u=sqrt(10/(M*eta*nu)*(2+log(2*no))); % initial excitation

% initialize vector of chosen candidate models
ik=zeros(1,T);
ik(1)=sampleFun(pk(:,1));   % sample from initial distribution

% initialize vector of excitation
sigma_uk=zeros(1,T);
sigma_uk(1)=sigma_u;

% initialize trajectories (state/input)
xtraj=zeros(n,T);
xtraj(:,1)=x1;
utraj=zeros(nu,T);
utraj(:,1)=randn(nu,1)*sigma_uk(1);   % first input is pure excitation

xtraj_opt=zeros(n,T);
xtraj_opt(:,1)=x1;

% initialize cost vectors
cost=zeros(1,T);
cost(1)=x1'*Q*x1+utraj(:,1)'*R*utraj(:,1);
cost_opt=zeros(1,T);
cost_opt(1)=x1'*Q*x1+x1'*Kopt'*R*Kopt*x1;

for k=2:T
    % generate process noise (shared by both closed loops)
    nk=randn(n,1);

    % simulate dynamics; the benchmark trajectory uses the same noise
    % sample nk as the learning trajectory
    xtraj_opt(:,k)=Acl_opt*xtraj_opt(:,k-1)+sigma*nk;
    xtraj(:,k)=Atrue*xtraj(:,k-1)+Btrue*utraj(:,k-1)+sigma*nk;

    % compute probability distribution over models ------------------------
    %
    % 1. keep track of prediction error of different experts
    for kk=1:no % loop over all experts
        xpn=Atot{kk}*xtraj(:,k-1)+Btot{kk}*utraj(:,k-1); % prediction of expert kk
        wn=xtraj(:,k)-xpn;                               % prediction error

        % loss for the multiplicative weight update
        lk(kk,k)=wn'*wn;
    end
    lsum=lsum+lk(:,k);      % accumulated loss up to and including step k
    %
    % 2. update corresponding weights
    %    The weights are recomputed from the accumulated loss with a single
    %    exp (rather than multiplying the previous weights by a per-step
    %    factor), so small per-step factors do not compound.
    wk(:,k)=exp(-eta*lsum);
    %
    % 3. update probability distribution over models
    %
    wtot=sum(wk(:,k));
    if wtot==0              % all weights underflowed to zero (accumulated
                            % loss is large): put all mass on the model
                            % with the smallest accumulated loss
        [~,idxmin]=min(lsum);
        pk(idxmin,k)=1;
    else
        pk(:,k)=wk(:,k)/wtot;
    end

    % determine the applied model: resample from pk every M steps,
    % otherwise keep the previously selected model
    if mod(k-1,M)==0
        % sample and store result
        ik(k)=sampleFun(pk(:,k));

        % update excitation (decays with the epoch counter ceil(k/M))
        sigma_u=sqrt(1/(M*eta*nu)*(2/ceil(k/M)+log(2*no)/ceil(k/M)^2));
    else
        ik(k)=ik(k-1);
    end

    % store excitation
    sigma_uk(k)=sigma_u;

    % compute/apply feedback policy (selected gain plus excitation noise)
    utraj(:,k)=Ktot{ik(k)}*xtraj(:,k)+randn(nu,1)*sigma_uk(k);

    % compute stage costs
    cost(:,k)=xtraj(:,k)'*Q*xtraj(:,k)+utraj(:,k)'*R*utraj(:,k);
    cost_opt(:,k)=xtraj_opt(:,k)'*Q*xtraj_opt(:,k)+xtraj_opt(:,k)'*Kopt'*R*Kopt*xtraj_opt(:,k);
end

% store the estimates and trajectories in corresponding structs
estimates=struct('pk',pk,'ik',ik);
trajectories=struct('xtraj',xtraj,'utraj',utraj,'xtraj_opt',xtraj_opt,'sigma_uk',sigma_uk);

end


%% sample from categorical distribution
%
% input:   p    vector defining probabilities (normalized to one)
% output:  out  random sample
%
function out=sampleFun(p)
% Draws one index from the categorical distribution given by the vector p
% via inverse-CDF sampling: a single uniform sample is compared against
% the running sum of p, and the first index whose running sum is not
% below the sample is returned. If the running sum never reaches the
% sample, the last index length(p) is returned.
threshold=rand(1);

acc=p(1);       % running sum of probabilities
out=1;
for idx=2:length(p)
    if ~(acc<threshold)
        break;  % running sum has reached the uniform sample
    end
    acc=acc+p(idx);
    out=idx;
end
end
